import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, regularizers

#（1）Mish激活函数
def mish(x):
    """Apply the Mish activation element-wise.

    Mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^x)).

    Args:
        x: Tensor (or tensor-convertible value) to activate.

    Returns:
        Tensor of the same shape as ``x``.
    """
    # softplus(x) == ln(1 + e^x); tanh of it acts as a smooth self-gate on x.
    gate = tf.tanh(tf.nn.softplus(x))
    return x * gate

#（2）标准卷积块
def conv_block(inputs, filters, kernel_size, strides, weight_decay=5e-4):
    """Conv2D -> BatchNormalization -> Mish.

    Args:
        inputs: Input tensor of the block.
        filters: Number of output channels of the convolution.
        kernel_size: Convolution kernel size, e.g. ``(3, 3)``.
        strides: Convolution stride. With ``'same'`` padding a stride of 2
            halves the spatial size (rounding up).
        weight_decay: L2 regularization factor applied to the convolution
            kernel. Defaults to ``5e-4``, the value previously hard-coded.

    Returns:
        Output tensor with ``filters`` channels.
    """
    x = layers.Conv2D(
        filters,
        kernel_size,
        strides,
        padding='same',
        use_bias=False,  # BatchNormalization follows, so a conv bias would be redundant
        kernel_regularizer=regularizers.l2(weight_decay),
    )(inputs)
    x = layers.BatchNormalization()(x)
    # Run Mish inside a real Keras layer rather than calling raw tf ops on the
    # symbolic tensor: Keras 3 (tf.keras in TF >= 2.16) refuses KerasTensors
    # passed to tf.* functions. The numerical result is identical.
    x = layers.Activation(mish)(x)
    return x

#（3）残差块
def res_block(inputs, filters, num):
    """Residual unit: 1x1 bottleneck conv, 3x3 conv, then add the input back.

    Args:
        inputs: Input tensor. Its channel count must equal the 3x3 conv's
            output width (``filters`` when ``num == 1``, otherwise
            ``filters // 2``) so that the element-wise Add is valid.
        filters: Channel width of the enclosing CSP stage.
        num: Repeat count of the enclosing stage (``csp_bolck`` passes its own
            ``num``). Here it only selects the output width: ``num == 1``
            keeps ``filters``, anything else uses ``filters // 2``.

    Returns:
        ``Add()([conv_output, inputs])``.
    """
    out_channels = filters if num == 1 else filters // 2

    # 1x1 conv squeezes to half the stage width
    hidden = conv_block(inputs, filters // 2, kernel_size=(1, 1), strides=1)
    # 3x3 conv restores the width of the identity path
    hidden = conv_block(hidden, out_channels, kernel_size=(3, 3), strides=1)

    # Identity shortcut: element-wise sum with the untouched input
    return layers.Add()([hidden, inputs])

#（4）CSP结构
def csp_bolck(inputs, filters, num):
    """CSP stage: stride-2 downsample, split into two 1x1 branches, stack
    ``num`` residual units on one branch, then concatenate and fuse.

    Args:
        inputs: Input feature map.
        filters: Output channel count of the stage.
        num: Number of residual units. When ``num == 1`` the two branches keep
            the full ``filters`` width; otherwise they are narrowed to
            ``filters // 2``.

    Returns:
        Feature map with ``filters`` channels at half the input resolution
        (rounded up, because the downsampling conv uses ``'same'`` padding).
    """
    branch_ch = filters if num == 1 else filters // 2

    # 3x3 stride-2 conv halves the spatial size
    down = conv_block(inputs, filters, kernel_size=(3, 3), strides=2)

    # The skip branch is created before the main branch. Keep this order so
    # Keras's auto-generated layer names stay stable.
    skip = conv_block(down, branch_ch, (1, 1), strides=1)
    main = conv_block(down, branch_ch, (1, 1), strides=1)

    # Stack the residual units on the main branch, then project it again
    for _ in range(num):
        main = res_block(inputs=main, filters=filters, num=num)
    main = conv_block(main, branch_ch, (1, 1), strides=1)

    # Merge both branches along channels and fuse back to `filters` channels
    merged = layers.concatenate([main, skip])
    return conv_block(merged, filters, (1, 1), strides=1)

#（5）主干网络
def cspdarknet(inputs):
    """Build the CSPDarknet53 backbone.

    Args:
        inputs: Input image tensor. The shape notes below assume a
            416x416x3 input.

    Returns:
        Tuple ``(feat1, feat2, feat3)`` holding the outputs of the last three
        CSP stages: 256 channels at 1/8 resolution, 512 at 1/16 and 1024 at
        1/32 (52x52, 26x26 and 13x13 for a 416x416 input).
    """
    # Stem: [416,416,3] ==> [416,416,32]
    x = conv_block(inputs, filters=32, kernel_size=(3, 3), strides=1)

    # (filters, residual units) per CSP stage. Each stage halves the resolution:
    # 208x208x64 -> 104x104x128 -> 52x52x256 -> 26x26x512 -> 13x13x1024
    stages = ((64, 1), (128, 2), (256, 8), (512, 8), (1024, 4))
    stage_outputs = []
    for stage_filters, stage_repeats in stages:
        x = csp_bolck(x, filters=stage_filters, num=stage_repeats)
        stage_outputs.append(x)

    # Expose the three deepest stage outputs
    feat1, feat2, feat3 = stage_outputs[-3:]
    return feat1, feat2, feat3

